# -*- coding: utf-8 -*-
import torch
from torch.optim.optimizer import Optimizer, required
import numpy as np
from utils import *
import time
import os
import csv
import torch.nn as nn
import argparse
from backend import *
from datasets import *

torch.manual_seed(1)

def is_conv_weight(weight):
    """Return True when ``weight`` has exactly four dimensions.

    Only ``weight.shape`` is inspected, so any array-like object exposing a
    ``shape`` attribute is accepted. Elsewhere in this file such 4-D tensors
    are regrouped along their first axis (one group per output kernel).
    """
    num_dims = len(weight.shape)
    return num_dims == 4

class AdaHSPG(Optimizer):
    """AdaHSPG optimizer: alternates Prox-SVRG steps and half-space steps.

    Which phase runs is decided by :meth:`switch`:

    * Prox-SVRG phase (``save_weight_and_grad_init_xs`` / ``save_grad_f`` /
      ``save_grad_f_hat`` / ``update_xs`` / ``proxsvrg_step``): 4-D weights
      take a group proximal step, other parameters a plain step.
    * Half-space phase (``enhanced_half_space_step``): selected groups of 4-D
      weights take a gradient step on psi and are zeroed when the trial point
      leaves the half-space ``<trial, x> >= epsilon * ||x||^2``.

    Groups are the slices of a 4-D weight along its first dimension, with the
    regularizer Omega(x) = sum_g ||[x]_g||_2 weighted by ``lmbda``.
    """

    def __init__(self, params, lr=required, epsilon=required, lmbda=required, kappa=required):
        """Validate hyper-parameters and register them as per-group defaults.

        Args:
            params: iterable of parameters or parameter-group dicts.
            lr: step size (must be >= 0).
            epsilon: half-space projection threshold (must be >= 0).
            lmbda: group-sparsity regularization weight (must be >= 0).
            kappa: a group is eligible for the half-space step only when
                ||x_g|| >= kappa * ||grad psi_g|| (must be >= 0).

        Raises:
            ValueError: if a supplied hyper-parameter is negative.
        """
        # Checked in the same order (and with the same messages) as before.
        for label, value in (("learning rate", lr), ("lambda", lmbda),
                             ("epsilon", epsilon), ("kappa", kappa)):
            if value is not required and value < 0.0:
                raise ValueError("Invalid {}: {}".format(label, value))

        defaults = dict(lr=lr, lmbda=lmbda, epsilon=epsilon, kappa=kappa)
        super(AdaHSPG, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaHSPG, self).__setstate__(state)

    def adjust_learning_rate(self, epoch, decay_every=75, decay_factor=10.0):
        """Divide every group's lr by ``decay_factor`` each ``decay_every`` epochs.

        No decay happens at epoch 0.
        """
        if epoch > 0 and epoch % decay_every == 0:
            for group in self.param_groups:
                group['lr'] /= float(decay_factor)

    def adjust_epsilon(self, curr_group_sparsity, prev_group_sparsity, norm_gradmap):
        """Adapt epsilon from the gradient-map norm and group-sparsity progress.

        * ``norm_gradmap <= 0.5`` and sparsity unchanged: epsilon is raised to
          at least 0.1, doubled, and capped at 0.999.
        * ``norm_gradmap > 0.5``: epsilon is halved.
        * otherwise: epsilon is left unchanged.

        Returns:
            The new epsilon of the last param group, or None if nothing changed.
        """
        new_epsilon = None
        if norm_gradmap <= 5e-1 and prev_group_sparsity == curr_group_sparsity:
            for group in self.param_groups:
                group['epsilon'] = min(max(group['epsilon'], 0.1) * 2.0, 0.999)
                new_epsilon = group['epsilon']
        elif norm_gradmap > 5e-1:
            for group in self.param_groups:
                group['epsilon'] = max(group['epsilon'] / 2.0, 0.0)
                new_epsilon = group['epsilon']
        return new_epsilon

    def get_grad_psi(self, x, grad_f, lmbda, kappa):
        """Return grad_f + lmbda * x_g / ||x_g|| for each row (group) x_g of ``x``.

        ``x`` and ``grad_f`` are expected to be 2-D (rows = groups). ``kappa``
        is accepted for interface compatibility but not used.
        """
        group_norms = torch.norm(x, p=2, dim=1).unsqueeze(1)
        return grad_f + lmbda * x / (group_norms + 1e-6)

    @staticmethod
    def _copy_into_state(state, key, source):
        """Copy ``source`` into ``state[key]``, allocating the buffer on first use."""
        if key not in state:
            state[key] = torch.zeros_like(source)
        state[key].copy_(source)
        return state[key]

    @staticmethod
    def _to_numpy(value):
        """Convert a counter (tensor or plain int) to a NumPy value.

        Counters stay plain Python ints when no 4-D parameter was visited.
        """
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def save_weight_and_grad_init_xs(self, num_batches):
        '''
            Revoked at the beginning of the epoch.

            Stores the Prox-SVRG snapshot for every parameter:
            hat_v = p.grad / num_batches (p.grad is expected to hold a gradient
            accumulated over ``num_batches`` mini-batches), hat_x = p,
            xs_end = hat_x, xs_sum = 0, and resets the inner counter i to 0.
        '''
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['i'] = 0
                self._copy_into_state(state, 'hat_v', p.grad.data).div_(num_batches)
                hat_x = self._copy_into_state(state, 'hat_x', p.data)
                self._copy_into_state(state, 'xs_end', hat_x)
                if 'xs_sum' not in state:
                    state['xs_sum'] = torch.zeros_like(hat_x)
                state['xs_sum'].zero_()

    def set_weights_from_xs(self):
        """Load the current inner iterate xs_end into every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                p.data.copy_(self.state[p]['xs_end'])

    def save_grad_f(self):
        """Store the current gradient (taken at the inner iterate) as grad_f."""
        for group in self.param_groups:
            for p in group['params']:
                self._copy_into_state(self.state[p], 'grad_f', p.grad.data)

    def save_grad_f_hat(self):
        """Store the current gradient (taken at the snapshot hat_x) as grad_f_hat."""
        for group in self.param_groups:
            for p in group['params']:
                self._copy_into_state(self.state[p], 'grad_f_hat', p.grad.data)

    def set_weights_from_hat_x(self):
        """Load the snapshot hat_x into every parameter."""
        for group in self.param_groups:
            for p in group['params']:
                p.data.copy_(self.state[p]['hat_x'])

    def update_xs(self):
        """Take one variance-reduced inner step and accumulate the result in xs_sum.

        v = grad_f - grad_f_hat + hat_v. 4-D weights move by the group
        proximal mapping; all other parameters take a plain step of size lr.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['i'] += 1
                v = state['grad_f'] - state['grad_f_hat'] + state['hat_v']
                xs_end = state['xs_end']
                # prox_mapping_group regroups along dim 0 of a 4-D weight, so
                # test the parameter itself with this file's helper (the old
                # call to the undefined ``is_conv_weights(p.shape)`` failed).
                if is_conv_weight(p):
                    xs_end.add_(self.prox_mapping_group(xs_end, v, group['lmbda'], group['lr']))
                else:
                    xs_end.add_(v, alpha=-group['lr'])
                state['xs_sum'].add_(xs_end)

    def proxsvrg_step(self, closure=None):
        """Set each parameter to the average inner iterate xs_sum / i.

        Parameters without a gradient are skipped. Parameters with no
        accumulated inner step (i == 0) are left untouched, because averaging
        would divide by zero and write NaNs into the weights.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if state.get('i', 0) == 0:
                    continue
                p.data.copy_(state['xs_sum'] / state['i'])
        return loss

    def prox_mapping_group_flatten(self, x, grad_f, lmbda, lr):
        """Group proximal step for a 2-D ``x`` whose rows are the groups.

        Returns the displacement delta such that x + delta is the proximal
        iterate of x - lr * grad_f; rows whose shrink factor is <= 0 become zero.
        """
        trial_x = x - lr * grad_f
        numer = lr * lmbda
        denom = torch.norm(trial_x, p=2, dim=1)
        coeffs = 1.0 - numer / (denom + 1e-6)
        coeffs[coeffs <= 0] = 0.0
        coeffs = coeffs.unsqueeze(-1)
        trial_x = coeffs * trial_x
        return trial_x - x

    def prox_mapping_group(self, x, grad_f, lmbda, alpha):
        '''
            Proximal Mapping for next iterate for Omega(x) = sum_{g in G}||[x]_g||_2

            Groups are the slices of ``x`` along its first dimension (x is 4-D
            for conv weights; any rank >= 1 is accepted). Returns the
            displacement delta such that x + delta is the proximal iterate.
        '''
        trial_x = x - alpha * grad_f
        num_groups = x.shape[0]
        denoms = torch.norm(trial_x.view(num_groups, -1), p=2, dim=1)
        coeffs = 1.0 - alpha * lmbda / (denoms + 1e-6)
        coeffs[coeffs <= 0] = 0.0
        # Shape (num_groups, 1, ..., 1) so the factor broadcasts over each group.
        coeffs = coeffs.view((num_groups,) + (1,) * (x.dim() - 1))
        return coeffs * trial_x - x

    def _partition_groups(self, x, grad_f, group):
        """Split the rows of a flattened 4-D weight into half-space and PSG groups.

        A row is a half-space group when ||x_g|| > 0, ||x_g|| >= kappa *
        ||grad psi_g|| and its proximal iterate is non-zero; all other rows
        are PSG groups.

        Returns:
            (grad_psi, prox_delta, lhs, hs_indices, psg_indices), where lhs
            holds the row norms of ``x`` and the index tensors are boolean masks.
        """
        grad_psi = self.get_grad_psi(x, grad_f, group['lmbda'], group['kappa'])
        prox_delta = self.prox_mapping_group_flatten(x, grad_f, group['lmbda'], group['lr'])
        lhs = torch.norm(x, p=2, dim=1)
        rhs = group['kappa'] * torch.norm(grad_psi, p=2, dim=1)
        prox_iter_norm = torch.norm(x + prox_delta, p=2, dim=1)
        hs_indices = (lhs > 0) & (lhs >= rhs) & (prox_iter_norm > 0)
        return grad_psi, prox_delta, lhs, hs_indices, ~hs_indices

    def enhanced_half_space_step(self, closure=None):
        """Run one half-space step and return the PSG / half-space group counts.

        For 4-D weights, half-space groups take a step on grad psi and are
        zeroed when <trial, x_g> < epsilon * ||x_g||^2; groups that are already
        zero stay zero. Other parameters take a plain gradient step.

        Returns:
            (|I_psg|, |I_hs|) summed over all 4-D weights, as NumPy values
            (zero when no 4-D weight has a gradient).
        """
        loss = None
        if closure is not None:
            loss = closure()
        ikpsg_size, ikhs_size = 0, 0
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                if not is_conv_weight(p.data):
                    p.data.add_(p.grad.data, alpha=-lr)
                    continue

                num_groups, c, h, w = p.data.shape
                x = p.data.view(num_groups, -1)
                grad_f = p.grad.data.view(num_groups, -1)
                grad_psi, _, lhs, hs_indices, psg_indices = self._partition_groups(x, grad_f, group)
                ikpsg_size += torch.sum(psg_indices)
                ikhs_size += torch.sum(hs_indices)

                # half-space step
                x_hs = x[hs_indices]
                trial_x = x_hs - lr * grad_psi[hs_indices]
                # conduct half-space projection
                proj_1 = torch.sum(trial_x * x_hs, dim=1)
                proj_2 = group['epsilon'] * torch.norm(x_hs, p=2, dim=1) ** 2
                trial_x[proj_1 < proj_2] = 0.0
                p.data[hs_indices] = trial_x.view(-1, c, h, w)
                p.data[lhs == 0] = 0.0
        return self._to_numpy(ikpsg_size), self._to_numpy(ikhs_size)

    def switch(self, mu=1.0):
        """Decide whether the next phase should be the half-space phase.

        Sums the norms of the proximal displacement over PSG groups (norm_psg)
        and over half-space groups (norm_hs) of every 4-D weight.

        Returns:
            (norm_psg <= mu * norm_hs, norm_psg + norm_hs).
        """
        print("switch begin")
        norm_psg, norm_hs = 0.0, 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or not is_conv_weight(p.data):
                    continue
                num_groups = p.data.shape[0]
                x = p.data.view(num_groups, -1)
                grad_f = p.grad.data.view(num_groups, -1)
                _, prox_delta, _, hs_indices, psg_indices = self._partition_groups(x, grad_f, group)
                norm_psg += torch.norm(prox_delta[psg_indices], p=2)
                norm_hs += torch.norm(prox_delta[hs_indices], p=2)

        print("norm of prox_grad_psg", norm_psg, ", norm of prox_grad_hs", norm_hs)
        return bool(norm_psg <= mu * norm_hs), norm_psg + norm_hs
    

def ParseArgs():
    """Parse the command line of the AdaHSPG training script.

    Options and defaults: --lmbda 1e-4, --max_epoch 100, --epsilon 0.1,
    --mu 1.0, --lr 0.1, --kappa 1e-3, --period 5, --batch_size 128,
    --backend 'vgg16' (vgg16 | resnet18), --dataset_name 'cifar10'
    (cifar10 | mnist).

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lmbda', default=1e-4, type=float, help='weighting parameters')
    # The remaining options carry no help text: (flag, default, type).
    plain_options = [
        ('--max_epoch', 100, int),
        ('--epsilon', 0.1, float),
        ('--mu', 1.0, float),
        ('--lr', 1e-1, float),
        ('--kappa', 1e-3, float),
        ('--period', 5, int),
        ('--batch_size', 128, int),
        ('--backend', 'vgg16', str),          # vgg16 | resnet18
        ('--dataset_name', 'cifar10', str),   # cifar10 | mnist
    ]
    for flag, default_value, value_type in plain_options:
        parser.add_argument(flag, default=default_value, type=value_type)
    parsed = parser.parse_args()
    return parsed

if __name__ == "__main__":
    # ---- configuration -------------------------------------------------
    args = ParseArgs()
    lmbda = args.lmbda
    epsilon = args.epsilon
    mu = args.mu
    kappa = args.kappa
    max_epoch = args.max_epoch
    period = args.period
    backend = args.backend
    dataset_name = args.dataset_name
    lr = args.lr
    batch_size = args.batch_size

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, testloader = Dataset(dataset_name, batch_size=batch_size)
    model = Model(backend=backend, device=device)

    # Parameters whose name contains "weight" are the ones handed to the
    # objective / sparsity helpers below.
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()

    optimizer = AdaHSPG(model.parameters(), lr=lr, lmbda=lmbda, epsilon=epsilon, kappa=kappa)

    print('Accuracy:', check_accuracy(model, testloader))

    os.makedirs('results', exist_ok=True)
    file_name = 'adahspg_group_%s_%s_%E_eps_%.2f_kappa_%.6f' % (backend, dataset_name, lmbda, epsilon, kappa)
    csvname = 'results/' + file_name + '.csv'
    print('The csv file is %s'%csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc', 'train_time', 'lr', 'eps', 'norm_gradmap', '|Ikpgs|', '|Ikhs|', '|G|', 'remarks']
    # `with` guarantees the CSV is closed even if training raises midway.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        # Epoch-0 row: objective, sparsity and accuracy before any update.
        F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
        sparsity, sparsity_tol, sparsity_group, num_groups = compute_sparsity(weights)
        writer.writerow({'epoch': 0, 'F_value': F, 'f_value': f, 'omega_value': omega, 'sparsity': sparsity, \
            'sparsity_tol': sparsity_tol, 'sparsity_group': sparsity_group,\
            'validation_acc': check_accuracy(model, testloader), 'train_time': 'N/A', 'lr': 'N/A', 'eps': epsilon, 'norm_gradmap': 'N/A',\
            '|Ikpgs|': 'N/A', '|Ikhs|': 'N/A', '|G|': num_groups, 'remarks': '%s;%s;before optimization'%(backend, dataset_name)})
        csvfile.flush()

        alg_start_time = time.time()

        epoch = 0
        do_half_space = False
        prev_group_sparsity = 1
        while epoch < max_epoch:
            epoch_start_time = time.time()

            norm_gradmap = None
            if epoch % period == 0:
                # Every `period` epochs: accumulate the full gradient over the
                # training set, reset the Prox-SVRG snapshot, and let the
                # optimizer choose between Prox-SVRG and half-space steps.
                print("Switching mechanism...")
                model.train()
                optimizer.zero_grad()
                for X, y in trainloader:
                    X = X.to(device)
                    y = y.to(device)
                    y_pred = model.forward(X)
                    f = criterion(y_pred, y)
                    f.backward()
                optimizer.save_weight_and_grad_init_xs(len(trainloader))
                do_half_space, norm_gradmap = optimizer.switch(mu)
                print("Proceed proxsvrg step" if not do_half_space else "Proceed half-space step")

            optimizer.zero_grad()
            sum_ikpsg_size, sum_ikhs_size = 0, 0
            model.train()
            for X, y in trainloader:
                X = X.to(device)
                y = y.to(device)

                if not do_half_space:
                    optimizer.set_weights_from_xs()

                y_pred = model.forward(X)
                optimizer.zero_grad()

                f = criterion(y_pred, y)
                f.backward()
                if do_half_space:
                    ikpsg_size, ikhs_size = optimizer.enhanced_half_space_step()
                    sum_ikpsg_size += ikpsg_size
                    sum_ikhs_size += ikhs_size
                else:
                    # Variance-reduced inner step: gradient at the inner
                    # iterate (grad_f), then at the snapshot (grad_f_hat).
                    optimizer.save_grad_f()
                    optimizer.set_weights_from_hat_x()
                    y_pred = model.forward(X)
                    f = criterion(y_pred, y)
                    optimizer.zero_grad()
                    f.backward()
                    optimizer.save_grad_f_hat()
                    optimizer.update_xs()

            if not do_half_space:
                optimizer.proxsvrg_step()

            optimizer.adjust_learning_rate(epoch)
            epoch += 1
            train_time = time.time() - epoch_start_time
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group, num_groups = compute_sparsity(weights)
            if do_half_space and epoch >= 150 and (epoch - 1) % period == 0 and norm_gradmap is not None:
                # adjust_epsilon returns None when it leaves epsilon unchanged;
                # keep the previous value so the 'eps' column is never blank.
                new_epsilon = optimizer.adjust_epsilon(sparsity_group, prev_group_sparsity, norm_gradmap)
                if new_epsilon is not None:
                    epsilon = new_epsilon
            prev_group_sparsity = sparsity_group
            accuracy = check_accuracy(model, testloader)

            if norm_gradmap is None:
                norm_gradmap_value = 'N/A'
            elif torch.is_tensor(norm_gradmap):
                norm_gradmap_value = norm_gradmap.detach().cpu().numpy()
            else:
                # switch() yields a plain float when no 4-D weight was visited.
                norm_gradmap_value = norm_gradmap

            writer.writerow({'epoch': epoch, 'F_value': F, 'f_value': f, 'omega_value': omega, 'sparsity': sparsity, \
                'sparsity_tol': sparsity_tol, 'sparsity_group': sparsity_group,\
                'validation_acc': accuracy, 'train_time': train_time, 'lr': optimizer.param_groups[0]['lr'], \
                'eps': epsilon, 'norm_gradmap': norm_gradmap_value, \
                '|Ikpgs|': sum_ikpsg_size / float(len(trainloader)), '|Ikhs|': sum_ikhs_size / float(len(trainloader)), '|G|': num_groups, \
                'remarks': '%s;%s;%E;%s'%(backend, dataset_name, lmbda, "half-space" if do_half_space else "proxsg")})
            csvfile.flush()
            print(". Epoch time: {:.2f}seconds ...".format(train_time))

        alg_time = time.time() - alg_start_time
        # max(..., 1) avoids ZeroDivisionError when max_epoch == 0.
        writer.writerow({'train_time': alg_time / max(epoch, 1)})

    # models/ is not created anywhere else; without this the save fails only
    # after the whole training run has finished.
    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/' + file_name + '.pt')

